# Copyright 2019-2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
#     or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

from causalnex.discretiser.discretiser_strategy import (
    DecisionTreeSupervisedDiscretiserMethod,
)


class TestDecisionTree:
    def test_single_continuous(self, continuous_data):
        """Check ``single`` mode on feature ``s6`` with a continuous target.

        The fitted discretiser's output for ``s6`` is reshaped to 13 columns
        and compared element-wise against a hard-coded table of bin indices.
        """
        # Expected bins were produced by fitting a DecisionTree manually and
        # extracting its thresholds.
        expected_bins = np.array(
            [
                [1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
                [1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 3, 0, 0],
                [1, 2, 0, 2, 1, 0, 1, 1, 0, 1, 1, 1, 1],
                [1, 2, 1, 0, 2, 1, 0, 0, 0, 1, 1, 1, 1],
                [1, 2, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1],
                [1, 1, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 1],
                [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1],
                [1, 1, 0, 0, 1, 1, 2, 0, 0, 1, 2, 1, 1],
                [0, 1, 0, 1, 2, 2, 1, 0, 1, 2, 1, 1, 1],
                [3, 2, 1, 1, 1, 2, 2, 1, 1, 0, 0, 0, 2],
                [2, 0, 1, 1, 0, 2, 1, 1, 2, 1, 1, 3, 1],
                [0, 1, 0, 1, 2, 1, 0, 1, 1, 2, 1, 1, 2],
                [0, 1, 0, 1, 0, 2, 1, 2, 1, 1, 0, 2, 3],
                [1, 1, 0, 2, 0, 1, 1, 1, 1, 1, 1, 1, 0],
                [1, 2, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1],
                [1, 1, 2, 0, 1, 2, 1, 1, 1, 2, 1, 1, 2],
                [2, 1, 0, 1, 1, 1, 0, 2, 2, 2, 1, 1, 0],
                [0, 1, 0, 1, 1, 0, 1, 0, 1, 2, 1, 1, 0],
                [2, 2, 1, 1, 1, 2, 1, 2, 0, 0, 1, 0, 0],
                [0, 2, 1, 2, 2, 1, 2, 2, 1, 1, 0, 2, 1],
                [1, 1, 1, 1, 1, 2, 1, 1, 2, 0, 1, 1, 0],
                [2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1],
                [1, 1, 1, 0, 2, 0, 0, 1, 0, 1, 0, 0, 1],
                [1, 2, 2, 1, 1, 1, 0, 1, 2, 0, 2, 1, 1],
                [1, 1, 0, 1, 2, 2, 1, 2, 1, 2, 2, 2, 1],
                [2, 1, 1, 1, 2, 1, 1, 2, 1, 0, 0, 2, 1],
                [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 3],
                [0, 0, 1, 2, 1, 0, 0, 0, 2, 2, 1, 2, 1],
                [2, 2, 2, 1, 2, 1, 0, 1, 1, 1, 0, 1, 1],
                [0, 0, 1, 1, 0, 2, 1, 0, 1, 1, 0, 0, 0],
                [2, 0, 2, 1, 1, 0, 0, 1, 0, 1, 1, 2, 2],
                [2, 1, 1, 0, 1, 2, 1, 0, 1, 2, 1, 1, 1],
                [2, 1, 1, 1, 2, 1, 1, 0, 1, 0, 2, 1, 3],
                [0, 1, 1, 2, 1, 1, 0, 0, 1, 2, 1, 1, 1],
            ]
        )

        # Work on a private copy so the shared fixture is never touched.
        frame = continuous_data.copy(deep=True)

        discretiser = DecisionTreeSupervisedDiscretiserMethod(
            mode="single",
            tree_params={"max_depth": 2},
        )
        fitted = discretiser.fit(
            dataframe=frame,
            feat_names=["s6"],
            target="target",
            target_continuous=True,
        )
        actual_bins = fitted.transform(frame[["s6"]]).values

        assert (actual_bins.reshape(-1, 13) == expected_bins).all()

    def test_single_categorical(self, categorical_data):
        """Check ``single`` mode on ``petal width (cm)`` with a categorical target.

        The transformed column is reshaped to 15 columns and compared
        element-wise against a hard-coded table of bin indices.
        """
        feature = "petal width (cm)"
        frame = categorical_data.copy(deep=True)

        discretiser = DecisionTreeSupervisedDiscretiserMethod(
            mode="single",
            tree_params={"max_depth": 2},
        )
        fitted = discretiser.fit(
            dataframe=frame,
            feat_names=[feature],
            target="target",
            target_continuous=False,
        )
        actual_bins = fitted.transform(frame[[feature]]).values

        # Expected bins were produced by fitting a DecisionTree manually and
        # extracting its thresholds.
        expected_bins = np.array(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                [2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1],
                [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1],
                [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            ]
        )

        assert (actual_bins.reshape(-1, 15) == expected_bins).all()

    def test_invalid_mode(self):
        """Constructing the discretiser with ``mode="invalid"`` raises ``KeyError``."""
        bad_kwargs = {"tree_params": {"max_depth": 2}, "mode": "invalid"}

        with pytest.raises(KeyError):
            DecisionTreeSupervisedDiscretiserMethod(**bad_kwargs)

    def test_transform_no_feature(self, get_iris_data, caplog):
        """Transforming a column that was not fitted must leave it unchanged.

        The discretiser is fitted in ``multi`` mode on sepal length and petal
        length only, then asked to transform ``sepal width (cm)``. The test
        checks that a "left unchanged" message is logged and that the returned
        values equal the input values.
        """
        # Independent copy of the expected values, taken before fit/transform.
        ground_truth = get_iris_data[["sepal width (cm)"]].copy(deep=True)
        dt_multi = DecisionTreeSupervisedDiscretiserMethod(
            mode="multi",
            split_unselected_feat=False,
            tree_params={"max_depth": 3, "random_state": 2020},
        )
        tree_discretiser = dt_multi.fit(
            feat_names=["sepal length (cm)", "petal length (cm)"],
            dataframe=get_iris_data,
            target_continuous=False,
            target="target",
        )

        output = tree_discretiser.transform(get_iris_data[["sepal width (cm)"]])

        assert "The column is left unchanged" in caplog.text
        # NOTE: the builtin ``all(df_a == df_b)`` iterates over the column
        # *labels* of the boolean frame (non-empty strings, hence truthy), so it
        # passes regardless of the data. Compare the underlying values instead.
        assert np.array_equal(ground_truth.values, output.values)

    def test_keep_unselected_feature(self, get_iris_data):
        """Check ``multi`` mode with ``split_unselected_feat=True`` on sepal width.

        All four iris features are fitted; the transformed
        ``sepal width (cm)`` column is reshaped to 15 columns and compared
        element-wise against a hard-coded table of bin indices.
        """
        all_features = [
            "sepal length (cm)",
            "sepal width (cm)",
            "petal length (cm)",
            "petal width (cm)",
        ]

        discretiser = DecisionTreeSupervisedDiscretiserMethod(
            mode="multi",
            split_unselected_feat=True,
            tree_params={"max_depth": 3, "random_state": 2020},
        )
        fitted = discretiser.fit(
            dataframe=get_iris_data,
            feat_names=all_features,
            target="target",
            target_continuous=False,
        )
        actual_bins = fitted.transform(get_iris_data[["sepal width (cm)"]]).values

        # Expected bins were produced by fitting a DecisionTree manually and
        # extracting its thresholds.
        expected_bins = np.array(
            [
                [4, 2, 3, 3, 4, 6, 4, 4, 1, 3, 4, 4, 2, 2, 6],
                [6, 6, 4, 5, 5, 4, 4, 4, 3, 4, 2, 4, 4, 4, 3],
                [3, 4, 6, 6, 3, 3, 4, 4, 2, 4, 4, 0, 3, 4, 5],
                [2, 5, 3, 4, 3, 3, 3, 3, 0, 1, 1, 3, 0, 1, 1],
                [0, 2, 0, 1, 1, 3, 2, 1, 0, 1, 3, 1, 1, 1, 1],
                [2, 1, 2, 1, 1, 0, 0, 1, 1, 2, 4, 3, 0, 2, 1],
                [1, 2, 1, 0, 1, 2, 1, 1, 1, 1, 3, 1, 2, 1, 2],
                [2, 1, 1, 1, 4, 3, 1, 2, 1, 1, 3, 2, 5, 1, 0],
                [3, 1, 1, 1, 3, 3, 1, 2, 1, 2, 1, 5, 1, 1, 1],
                [2, 4, 3, 2, 3, 3, 3, 1, 3, 3, 2, 1, 2, 4, 2],
            ]
        )

        assert (actual_bins.reshape(-1, 15) == expected_bins).all()

    def test_multi_fit(self, get_iris_data):
        """Check ``multi`` mode output for the two petal features of iris.

        One discretiser is fitted on all four features; ``petal length (cm)``
        and ``petal width (cm)`` are then transformed separately and each
        result (reshaped to 15 columns) is compared with its expected table.
        """
        frame = get_iris_data.copy(deep=True)

        discretiser = DecisionTreeSupervisedDiscretiserMethod(
            mode="multi", tree_params={"max_depth": 3, "random_state": 2020}
        )
        fitted = discretiser.fit(
            dataframe=frame,
            feat_names=[
                "sepal length (cm)",
                "sepal width (cm)",
                "petal length (cm)",
                "petal width (cm)",
            ],
            target="target",
            target_continuous=False,
        )

        actual_petal_length = fitted.transform(frame[["petal length (cm)"]]).values
        actual_petal_width = fitted.transform(frame[["petal width (cm)"]]).values

        # Both expected tables were produced by fitting a DecisionTree manually
        # and extracting its thresholds.
        expected_petal_length = np.array(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                [0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2],
                [2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                [2, 1, 2, 1, 2, 2, 0, 1, 2, 2, 2, 2, 2, 2, 2],
                [2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            ]
        )
        expected_petal_width = np.array(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                [2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1],
                [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1],
                [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            ]
        )

        assert (actual_petal_length.reshape(-1, 15) == expected_petal_length).all()
        assert (actual_petal_width.reshape(-1, 15) == expected_petal_width).all()

    def test_no_unselected_feature(self, get_iris_data):
        """With ``split_unselected_feat=False`` sepal width must come back unchanged.

        The discretiser is fitted in ``multi`` mode on all four iris features
        and then transforms ``sepal width (cm)``; the returned values are
        expected to equal the input values.
        """
        # Independent copy of the expected values, taken before fit/transform.
        ground_truth = get_iris_data[["sepal width (cm)"]].copy(deep=True)
        dt_multi = DecisionTreeSupervisedDiscretiserMethod(
            tree_params={"max_depth": 3, "random_state": 2020},
            mode="multi",
            split_unselected_feat=False,
        )
        tree_discretiser = dt_multi.fit(
            feat_names=[
                "sepal length (cm)",
                "sepal width (cm)",
                "petal length (cm)",
                "petal width (cm)",
            ],
            dataframe=get_iris_data,
            target_continuous=False,
            target="target",
        )
        output = tree_discretiser.transform(get_iris_data[["sepal width (cm)"]])

        # NOTE: the builtin ``all(df_a == df_b)`` iterates over the column
        # *labels* of the boolean frame (non-empty strings, hence truthy), so it
        # passes regardless of the data. Compare the underlying values instead.
        assert np.array_equal(ground_truth.values, output.values)

    def test_default_args(self):
        """A discretiser built with no arguments reports ``max_depth == 2``."""
        default_params = DecisionTreeSupervisedDiscretiserMethod().get_params()
        tree_settings = default_params["tree_params"]
        assert tree_settings["max_depth"] == 2

    def test_transform_all_single(self, get_iris_data):
        """Transform the whole iris frame in ``single`` mode and check every feature.

        The discretiser is fitted on all four features, the full dataframe is
        transformed in one call, and each feature column (reshaped to 15
        columns) is compared with its hard-coded expected table.
        """
        feature_columns = [
            "sepal length (cm)",
            "sepal width (cm)",
            "petal length (cm)",
            "petal width (cm)",
        ]
        # Expected tables, listed in the same order as ``feature_columns``.
        expected_tables = [
            # sepal length (cm)
            np.array(
                [
                    [1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 2],
                    [2, 1, 1, 2, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0],
                    [0, 1, 1, 2, 1, 1, 2, 1, 0, 1, 1, 0, 0, 1, 1],
                    [0, 1, 0, 1, 1, 3, 3, 3, 2, 3, 2, 3, 1, 3, 1],
                    [1, 2, 2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 3],
                    [3, 3, 3, 2, 2, 2, 2, 2, 2, 1, 2, 3, 3, 2, 2],
                    [2, 2, 2, 1, 2, 2, 2, 3, 1, 2, 3, 2, 3, 3, 3],
                    [3, 1, 3, 3, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2],
                    [3, 2, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 2],
                    [3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 2],
                ]
            ),
            # sepal width (cm)
            np.array(
                [
                    [2, 1, 1, 1, 2, 3, 2, 2, 0, 1, 2, 2, 1, 1, 3],
                    [3, 3, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2, 1],
                    [1, 2, 3, 3, 1, 1, 2, 2, 1, 2, 2, 0, 1, 2, 2],
                    [1, 2, 1, 2, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0],
                    [0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                    [1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 2, 1, 0, 1, 0],
                    [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1],
                    [1, 0, 0, 0, 2, 1, 0, 1, 0, 0, 1, 1, 2, 0, 0],
                    [1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 2, 0, 0, 0],
                    [1, 2, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 1],
                ]
            ),
            # petal length (cm)
            np.array(
                [
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1],
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1],
                    [1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1],
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                    [2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                ]
            ),
            # petal width (cm)
            np.array(
                [
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1],
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                    [2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1],
                    [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1],
                    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                ]
            ),
        ]

        frame = get_iris_data.copy(deep=True)
        dt_single = DecisionTreeSupervisedDiscretiserMethod(
            tree_params={"max_depth": 2, "random_state": 2020}, mode="single"
        )
        fitted = dt_single.fit(
            dataframe=frame,
            feat_names=feature_columns,
            target_continuous=False,
            target="target",
        )
        transformed = fitted.transform(frame)

        # Columns are checked in the same order as the original assertions.
        for column, expected in zip(feature_columns, expected_tables):
            assert (transformed[column].values.reshape(-1, 15) == expected).all()

    def test_transform_all_multi(self, get_iris_data):
        """Transform the whole iris frame in ``multi`` mode and check every feature.

        The two sepal columns are expected to equal the values read from the
        dataframe before fitting; the two petal columns (reshaped to 15
        columns) are compared with hard-coded tables of bin indices.
        """
        frame = get_iris_data.copy(deep=True)

        # Captured before fit/transform, exactly as the expected sepal values.
        expected_sepal_length = frame["sepal length (cm)"]
        expected_sepal_width = frame["sepal width (cm)"]
        expected_petal_length = np.array(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                [0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2],
                [2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                [2, 1, 2, 1, 2, 2, 0, 1, 2, 2, 2, 2, 2, 2, 2],
                [2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            ]
        )
        expected_petal_width = np.array(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                [2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1],
                [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1],
                [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            ]
        )

        discretiser = DecisionTreeSupervisedDiscretiserMethod(
            tree_params={"max_depth": 3, "random_state": 2020}, mode="multi"
        )
        fitted = discretiser.fit(
            dataframe=frame,
            feat_names=[
                "sepal length (cm)",
                "sepal width (cm)",
                "petal length (cm)",
                "petal width (cm)",
            ],
            target_continuous=False,
            target="target",
        )
        transformed = fitted.transform(frame)

        actual_sepal_length = transformed["sepal length (cm)"].values
        actual_sepal_width = transformed["sepal width (cm)"].values
        actual_petal_length = transformed["petal length (cm)"].values
        actual_petal_width = transformed["petal width (cm)"].values

        assert (actual_sepal_length == expected_sepal_length).all()
        assert (actual_sepal_width == expected_sepal_width).all()
        assert (actual_petal_length.reshape(-1, 15) == expected_petal_length).all()
        assert (actual_petal_width.reshape(-1, 15) == expected_petal_width).all()

    def test_transform_shuffled_indices(self, get_iris_data):
        """Fit and transform a 50-row random sample of iris in ``single`` mode.

        The sample is drawn with ``random_state=2021``. The test checks that
        the transformed output contains no nulls and, reshaped to 10 x 5,
        equals the hard-coded table of bin indices.
        """
        feature = "sepal length (cm)"
        frame = get_iris_data.copy(deep=True)
        shuffled = frame[[feature, "target"]].sample(50, random_state=2021)

        discretiser = DecisionTreeSupervisedDiscretiserMethod(
            mode="single",
            tree_params={"max_depth": 2},
        )
        discretiser.fit(
            dataframe=shuffled,
            feat_names=[feature],
            target="target",
            target_continuous=False,
        )
        transformed = discretiser.transform(shuffled[[feature]])

        expected_bins = np.array(
            [
                [0, 0, 1, 0, 1],
                [0, 0, 0, 0, 0],
                [0, 1, 1, 2, 1],
                [2, 1, 1, 0, 1],
                [1, 2, 1, 0, 2],
                [2, 2, 2, 0, 0],
                [1, 1, 0, 2, 2],
                [1, 2, 0, 2, 2],
                [2, 2, 0, 1, 2],
                [1, 1, 2, 1, 1],
            ]
        )

        assert transformed.isnull().values.sum() == 0
        assert (transformed.values.reshape(10, 5) == expected_bins).all()
